from cvfo_mask import CvFo
from tqdm import tqdm
import paddle
import numpy as np
import pandas as pd

block_size = 8
voc = pd.read_pickle("voc.pandas_pickle")
data = pd.read_pickle("dataset.pandas_pickle")

t = paddle.to_tensor(
    [voc.loc[voc["voc"] == "<|p_{}|>".format(i)].values[0][1:-1].astype("int") for i in range(0, 128)]).T.unsqueeze(0)
model = CvFo(len(voc), 256, 6, block_size)
np.random.shuffle(data)
model.load_dict(paddle.load("cvfo_model.pdparams"))
model.eval()
batch_data = paddle.to_tensor(data[7]).astype('int64').reshape([1, 3, -1])
star_seq = batch_data[:, :2, :9]
star_index = star_seq.shape[-1] - 1


if star_seq.shape[-1] % block_size != 0:
    star_seq = paddle.concat(
        [star_seq, (len(voc) - 1) * paddle.ones([1, 2, 8 - star_seq.shape[-1] % block_size]).astype("int64")],
        axis=-1)
out = model.sample(star_seq,t)
out = [ voc.loc[voc["output_voc"].astype("int") ==i,"voc"].values[0]  for i in paddle.argmax(out, axis=-1)[0].numpy()]
print("".join(out))

# print([voc.loc[voc["output_voc"].astype("int") == i, "voc"].values[0] for i in out_list])
# ['<|aos|>', '芳', '岁', '归', '人', '嗟', '转', '蓬', '，', '<|p_65|>', '，', '<|p_65|>', '，', '<|p_113|>', '杨', '柳', '，', '<|p_66|>', '乳', '熟', '，', '<|p_83|>', '践', '，', '<|p_83|>', '默', '，', '<|p_83|>', '雨', '后', '，', '<|p_83|>', '雨', '，', '<|p_84|>', '雨', '后', '，', '<|p_83|>', '何', '处', '，', '<|p_70|>', '毒', '，', '<|p_83|>', '怅', '，', '<|p_70|>', '颭', '，', '<|p_82|>', '木', '无', '复', '，', '<|p_70|>', '徼', '，']
# ['<|aos|>', '头', '风', '目', '眩', '乘', '衰', '老', '，', '<|p_65|>', '，', '<|p_65|>', '，', '<|p_97|>', '，', '<|p_113|>', '娇', '不', '知', '何', '如', '此', '生', '不', '如', '何', '如', '此', '时', '。', '<|p_98|>', '不', '知', '。', '<|p_20|>', '包', '伏', '波', '。', '<|p_20|>', '包', '。', '<|p_70|>', '狡', '，', '<|p_83|>', '雪', '，', '<|p_7|>', '仙', '境', '，', '<|p_99|>', '<|p_100|>', '<|p_101|>', '正', '，', '<|p_81|>', '不']
# ['<|aos|>', '圣', '哲', '符', '休', '运', '，', '<|p_7|>', '伊', '水', '滨', '。', '<|p_50|>', '一', '生', '。', '<|p_98|>', '君', '家', '。', '<|p_98|>', 'n', '。', '<|p_114|>', '<|p_115|>', '纛', '则', '纷', '纷', '笏', '略', '，', '<|p_5|>', '大', '梁', '王', '妹', '，', '<|p_115|>', '<|p_116|>', '<|p_117|>', '入', '，', '<|p_18|>', '此', '地', '，', '<|p_6|>', '仙', '掌', '中', '。', '<|p_19|>', '风', '吹', '，', '<|p_54|>', '农', '。']
